import os

from colossalqa.retrieval_conversation_universal import UniversalRetrievalConversation


def test_en_retrievalQA():
    data_path_en = os.environ.get('TEST_DATA_PATH_EN')
    data_path_zh = os.environ.get('TEST_DATA_PATH_ZH')
    en_model_path = os.environ.get('EN_MODEL_PATH')
    zh_model_path = os.environ.get('ZH_MODEL_PATH')
    zh_model_name = os.environ.get('ZH_MODEL_NAME')
    en_model_name = os.environ.get('EN_MODEL_NAME')
    sql_file_path = os.environ.get('SQL_FILE_PATH')
    qa_session = UniversalRetrievalConversation(files_en=[{
        'data_path': data_path_en,
        'name': 'company information',
        'separator': '\n'
    }],
                                                files_zh=[{
                                                    'data_path': data_path_zh,
                                                    'name': 'company information',
                                                    'separator': '\n'
                                                }],
                                                zh_model_path=zh_model_path,
                                                en_model_path=en_model_path,
                                                zh_model_name=zh_model_name,
                                                en_model_name=en_model_name,
                                                sql_file_path=sql_file_path)
    ans = qa_session.run("which company runs business in hotel industry?", which_language='en')
    print(ans)


def test_zh_retrievalQA():
    data_path_en = os.environ.get('TEST_DATA_PATH_EN')
    data_path_zh = os.environ.get('TEST_DATA_PATH_ZH')
    en_model_path = os.environ.get('EN_MODEL_PATH')
    zh_model_path = os.environ.get('ZH_MODEL_PATH')
    zh_model_name = os.environ.get('ZH_MODEL_NAME')
    en_model_name = os.environ.get('EN_MODEL_NAME')
    sql_file_path = os.environ.get('SQL_FILE_PATH')
    qa_session = UniversalRetrievalConversation(files_en=[{
        'data_path': data_path_en,
        'name': 'company information',
        'separator': '\n'
    }],
                                                files_zh=[{
                                                    'data_path': data_path_zh,
                                                    'name': 'company information',
                                                    'separator': '\n'
                                                }],
                                                zh_model_path=zh_model_path,
                                                en_model_path=en_model_path,
                                                zh_model_name=zh_model_name,
                                                en_model_name=en_model_name,
                                                sql_file_path=sql_file_path)
    ans = qa_session.run("哪家公司在经营酒店业务？", which_language='zh')
    print(ans)


if __name__ == "__main__":
    test_en_retrievalQA()
    test_zh_retrievalQA()
